{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp params\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "from nbdev.showdoc import show_doc\n",
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Params\n",
    "\n",
    "`Params` is the major object to control the whole modeling process. It is supposed to be accessable anywhere. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data/anaconda3/lib/python3.8/site-packages/torch/cuda/__init__.py:52: UserWarning: CUDA initialization: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx (Triggered internally at  /pytorch/c10/cuda/CUDAFunctions.cpp:100.)\n",
      "  return torch._C._cuda_getDeviceCount() > 0\n"
     ]
    }
   ],
   "source": [
    "# export\n",
    "\n",
    "from m3tl.base_params import BaseParams\n",
    "from m3tl.embedding_layer.base import (DefaultMultimodalEmbedding,\n",
    "                                       DuplicateAugMultimodalEmbedding)\n",
    "from m3tl.loss_strategy.base import SumLossCombination\n",
    "from m3tl.mtl_model.mmoe import MMoE\n",
    "from m3tl.problem_types import cls as problem_type_cls\n",
    "from m3tl.problem_types import (contrastive_learning, masklm, multi_cls,\n",
    "                                premask_mlm, pretrain, regression, seq_tag,\n",
    "                                vector_fit)\n",
    "\n",
    "\n",
    "class Params(BaseParams):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # register pre-defined problem types\n",
    "        self.register_problem_type(problem_type='cls',\n",
    "                                   top_layer=problem_type_cls.Classification,\n",
    "                                   label_handling_fn=problem_type_cls.cls_label_handling_fn,\n",
    "                                   get_or_make_label_encoder_fn=problem_type_cls.cls_get_or_make_label_encoder_fn,\n",
    "                                   description='Classification')\n",
    "        self.register_problem_type(problem_type='multi_cls',\n",
    "                                   top_layer=multi_cls.MultiLabelClassification,\n",
    "                                   label_handling_fn=multi_cls.multi_cls_label_handling_fn,\n",
    "                                   get_or_make_label_encoder_fn=multi_cls.multi_cls_get_or_make_label_encoder_fn,\n",
    "                                   description='Multi-Label Classification')\n",
    "        self.register_problem_type(problem_type='seq_tag',\n",
    "                                   top_layer=seq_tag.SequenceLabel,\n",
    "                                   label_handling_fn=seq_tag.seq_tag_label_handling_fn,\n",
    "                                   get_or_make_label_encoder_fn=seq_tag.seq_tag_get_or_make_label_encoder_fn,\n",
    "                                   description='Sequence Labeling')\n",
    "        self.register_problem_type(problem_type='masklm',\n",
    "                                   top_layer=masklm.MaskLM,\n",
    "                                   label_handling_fn=masklm.masklm_label_handling_fn,\n",
    "                                   get_or_make_label_encoder_fn=masklm.masklm_get_or_make_label_encoder_fn,\n",
    "                                   description='Masked Language Model')\n",
    "        self.register_problem_type(problem_type='pretrain',\n",
    "                                   top_layer=pretrain.PreTrain,\n",
    "                                   label_handling_fn=pretrain.pretrain_label_handling_fn,\n",
    "                                   get_or_make_label_encoder_fn=pretrain.pretrain_get_or_make_label_encoder_fn,\n",
    "                                   description='NSP+MLM(Deprecated)')\n",
    "        self.register_problem_type(problem_type='regression',\n",
    "                                   top_layer=regression.Regression,\n",
    "                                   label_handling_fn=regression.regression_label_handling_fn,\n",
    "                                   get_or_make_label_encoder_fn=regression.regression_get_or_make_label_encoder_fn,\n",
    "                                   description='Regression')\n",
    "        self.register_problem_type(\n",
    "            problem_type='vector_fit',\n",
    "            top_layer=vector_fit.VectorFit,\n",
    "            label_handling_fn=vector_fit.vector_fit_label_handling_fn,\n",
    "            get_or_make_label_encoder_fn=vector_fit.vector_fit_get_or_make_label_encoder_fn,\n",
    "            description='Vector Fitting')\n",
    "        self.register_problem_type(\n",
    "            problem_type='premask_mlm',\n",
    "            top_layer=premask_mlm.PreMaskMLM,\n",
    "            label_handling_fn=premask_mlm.premask_mlm_label_handling_fn,\n",
    "            get_or_make_label_encoder_fn=premask_mlm.premask_mlm_get_or_make_label_encoder_fn,\n",
    "            description='Pre-masked Masked Language Model'\n",
    "        )\n",
    "        self.register_problem_type(\n",
    "            problem_type='contrastive_learning',\n",
    "            top_layer=contrastive_learning.ContrastiveLearning,\n",
    "            label_handling_fn=contrastive_learning.contrastive_learning_label_handling_fn,\n",
    "            get_or_make_label_encoder_fn=contrastive_learning.contrastive_learning_get_or_make_label_encoder_fn,\n",
    "            description='Contrastive Learning'\n",
    "        )\n",
    "\n",
    "        self.register_mtl_model(\n",
    "            'mmoe', MMoE, include_top=False, extra_info='MMoE')\n",
    "        self.register_loss_combination_strategy('sum', SumLossCombination)\n",
    "        self.register_embedding_layer(\n",
    "            'duplicate_data_augmentation_embedding', DuplicateAugMultimodalEmbedding)\n",
    "        self.register_embedding_layer(\n",
    "            'default_embedding', DefaultMultimodalEmbedding)\n",
    "\n",
    "        self.assign_loss_combination_strategy('sum')\n",
    "        self.assign_data_sampling_strategy()\n",
    "        self.assign_embedding_layer('default_embedding')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:bert_config not exists. will load model from huggingface checkpoint.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Adding new problem weibo_fake_ner, problem type: seq_tag\n",
      "Adding new problem weibo_cws, problem type: seq_tag\n",
      "Adding new problem weibo_fake_multi_cls, problem type: multi_cls\n",
      "Adding new problem weibo_fake_cls, problem type: cls\n",
      "Adding new problem weibo_masklm, problem type: masklm\n",
      "Adding new problem weibo_pretrain, problem type: pretrain\n",
      "Adding new problem weibo_fake_regression, problem type: regression\n",
      "Adding new problem weibo_fake_vector_fit, problem type: vector_fit\n",
      "Adding new problem weibo_premask_mlm, problem type: premask_mlm\n"
     ]
    }
   ],
   "source": [
    "# hide\n",
    "from m3tl.test_base import TestBase\n",
    "tb = TestBase()\n",
    "params = tb.params\n",
    "tmp_model_dir = tb.tmpckptdir"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Add Problems\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hide\n",
    "# define a simple preprocessing function\n",
    "import m3tl\n",
    "from m3tl import preprocessing_fn\n",
    "@preprocessing_fn\n",
    "def toy_cls(params: Params, mode: str):\n",
    "    \"Simple example to demonstrate singe modal tuple of list return\"\n",
    "    if mode == m3tl.TRAIN:\n",
    "        toy_input = ['this is a toy input' for _ in range(10)]\n",
    "        toy_target = ['a' for _ in range(10)]\n",
    "    else:\n",
    "        toy_input = ['this is a toy input for test' for _ in range(10)]\n",
    "        toy_target = ['a' for _ in range(10)]\n",
    "    return toy_input, toy_target\n",
    "\n",
    "@preprocessing_fn\n",
    "def toy_seq_tag(params: Params, mode: str):\n",
    "    \"Simple example to demonstrate singe modal tuple of list return\"\n",
    "    if mode == m3tl.TRAIN:\n",
    "        toy_input = ['this is a toy input'.split(' ') for _ in range(10)]\n",
    "        toy_target = [['a', 'b', 'c', 'd', 'e'] for _ in range(10)]\n",
    "    else:\n",
    "        toy_input = ['this is a toy input for test'.split(' ') for _ in range(10)]\n",
    "        toy_target = [['a', 'b', 'c', 'd', 'e', 'e', 'e'] for _ in range(10)]\n",
    "    return toy_input, toy_target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"BaseParams.register_problem\" class=\"doc_header\"><code>BaseParams.register_problem</code><a href=\"https://github.com/JayYip/m3tl/tree/master/m3tl/base_params.py#L474\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>BaseParams.register_problem</code>(**`problem_name`**:`str`, **`problem_type`**=*`'cls'`*, **`processing_fn`**:`Callable`=*`None`*)\n",
       "\n"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(BaseParams.register_problem)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Add problems.\n",
    "\n",
    "Args:\n",
    "- problem_name (str): problem name.\n",
    "- problem_type (str, optional): One of the following problem types:\n",
    "['cls', 'seq_tag', 'seq2seq_tag', 'seq2seq_text', 'multi_cls', 'pretrain'].\n",
    "Defaults to 'cls'.\n",
    "- processing_fn (Callable, optional): preprocessing function. Defaults to None.\n",
    "\n",
    "Raises:\n",
    "- ValueError: unexpected problem_type"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "params.register_problem(problem_name='toy_cls', problem_type='cls', processing_fn=toy_cls)\n",
    "params.register_problem(problem_name='toy_seq_tag', problem_type='seq_tag', processing_fn=toy_seq_tag)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"BaseParams.register_multiple_problems\" class=\"doc_header\"><code>BaseParams.register_multiple_problems</code><a href=\"https://github.com/JayYip/m3tl/tree/master/m3tl/base_params.py#L484\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>BaseParams.register_multiple_problems</code>(**`problem_type_dict`**:`Dict`\\[`str`, `str`\\], **`processing_fn_dict`**:`Dict`\\[`str`, `Callable`\\]=*`None`*)\n",
       "\n"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(BaseParams.register_multiple_problems)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Add multiple problems.\n",
    "\n",
    "processing_fn_dict is optional, if it's not provided, processing fn will be set as None.\n",
    "\n",
    "Args:\n",
    "- problem_type_dict (Dict[str, str]): problem type dict\n",
    "- processing_fn_dict (Dict[str, Callable], optional): problem type fn. Defaults to None."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Adding new problem toy_cls, problem type: cls\n",
      "Adding new problem toy_seq_tag, problem type: seq_tag\n"
     ]
    }
   ],
   "source": [
    "# make dict and add problems to params\n",
    "problem_type_dict = {'toy_cls': 'cls', 'toy_seq_tag': 'seq_tag'}\n",
    "processing_fn_dict = {'toy_cls': toy_cls, 'toy_seq_tag': toy_seq_tag}\n",
    "params.register_multiple_problems(problem_type_dict=problem_type_dict, processing_fn_dict=processing_fn_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Assign Problems"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"BaseParams.assign_problem\" class=\"doc_header\"><code>BaseParams.assign_problem</code><a href=\"https://github.com/JayYip/m3tl/tree/master/m3tl/base_params.py#L545\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>BaseParams.assign_problem</code>(**`flag_string`**:`str`, **`gpu`**=*`2`*, **`base_dir`**:`str`=*`None`*, **`dir_name`**:`str`=*`None`*, **`predicting`**=*`False`*)\n",
       "\n"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(BaseParams.assign_problem)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Assign the actual run problem to param. This function will\n",
    "do the following things:\n",
    "\n",
    "1. parse the flag string to form the run_problem_list\n",
    "2. create checkpoint saving path\n",
    "3. calculate total number of training data and training steps\n",
    "4. scale learning rate with the number of gpu linearly\n",
    "\n",
    "Arguments:\n",
    "- flag_string {str} -- run problem string\n",
    "- example: cws|POS|weibo_ner&weibo_cws\n",
    "\n",
    "Keyword Arguments:\n",
    "- gpu {int} -- number of gpu use for training, this will affect the training steps and learning rate (default: {2})\n",
    "- base_dir {str} -- base dir for ckpt, if None, then \"models\" is assigned (default: {None})\n",
    "- dir_name {str} -- dir name for ckpt, if None, will be created automatically (default: {None})\n",
    "- predicting {bool} -- whether is predicting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:bert_config not exists. will load model from huggingface checkpoint.\n"
     ]
    }
   ],
   "source": [
    "params.assign_problem(flag_string='toy_seq_tag|toy_cls', base_dir=tmp_model_dir)\n",
    "assert params.problem_assigned"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After problem assigned, the model path should be created with tokenizers, label encoder files in it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hide\n",
    "# assert os.listdir(params.ckpt_dir) == ['data_info.json',\n",
    "#  'tokenizer',\n",
    "#  'toy_cls_label_encoder.pkl',\n",
    "#  'toy_seq_tag_label_encoder.pkl',\n",
    "#  'bert_config']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Register new problem type\n",
    "\n",
    "You can also implement your own problem type. Essentially, a problem type has:\n",
    "- name\n",
    "- top layer\n",
    "- label handling function\n",
    "- label encoder creating function\n",
    "\n",
    "Here we register a vector fitting(vector annealing) problem type as an example.\n",
    "\n",
    "Note: This is originally designed as an internal API for development. So it's not user-friendly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"BaseParams.register_problem_type\" class=\"doc_header\"><code>BaseParams.register_problem_type</code><a href=\"https://github.com/JayYip/m3tl/tree/master/m3tl/base_params.py#L445\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>BaseParams.register_problem_type</code>(**`problem_type`**:`str`, **`top_layer`**:`Model`=*`None`*, **`label_handling_fn`**:`Callable`=*`None`*, **`get_or_make_label_encoder_fn`**:`Callable`=*`None`*, **`inherit_from`**:`str`=*`None`*)\n",
       "\n"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(BaseParams.register_problem_type)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "API to register a new problem type\n",
    "\n",
    "Args:\n",
    "- problem_type: string, problem type name\n",
    "- top_layer: a keras model with some specific reqirements\n",
    "- label_handling_fn: function to convert labels to label ids\n",
    "- get_or_make_label_encoder_fn: function to create label encoder, num_classes has to be specified here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from m3tl.problem_types.utils import BaseTop\n",
    "from m3tl.problem_types.utils import empty_tensor_handling_loss, nan_loss_handling\n",
    "import tensorflow as tf\n",
    "from typing import Tuple, Dict\n",
    "import numpy as np\n",
    "# top layer\n",
    "class VectorFit(BaseTop):\n",
    "    def __init__(self, params: Params, problem_name: str) -> None:\n",
    "        super(VectorFit, self).__init__(\n",
    "            params=params, problem_name=problem_name)\n",
    "        self.num_classes = self.params.num_classes[problem_name]\n",
    "        self.dense = tf.keras.layers.Dense(self.num_classes)\n",
    "\n",
    "    def call(self, inputs: Tuple[Dict], mode: str):\n",
    "        feature, hidden_feature = inputs\n",
    "        pooled_hidden = hidden_feature['pooled']\n",
    "\n",
    "        logits = self.dense(pooled_hidden)\n",
    "        if mode != tf.estimator.ModeKeys.PREDICT:\n",
    "            # this is the same as the label_id returned by vector_fit_label_handling_fn\n",
    "            label = feature['{}_label_ids'.format(self.problem_name)]\n",
    "\n",
    "            loss = empty_tensor_handling_loss(label, logits, cosine_wrapper)\n",
    "            loss = nan_loss_handling(loss)\n",
    "            self.add_loss(loss)\n",
    "\n",
    "            self.add_metric(tf.math.negative(\n",
    "                loss), name='{}_cos_sim'.format(self.problem_name), aggregation='mean')\n",
    "        return logits\n",
    "\n",
    "# label handling fn\n",
    "def vector_fit_label_handling_fn(target, label_encoder=None, tokenizer=None, decoding_length=None):\n",
    "    # don't need to encoder labels, return array directly\n",
    "    # return label_id and label mask\n",
    "    label_id = np.array(target, dtype='float32')\n",
    "    return label_id, None\n",
    "\n",
    "# make label encoder\n",
    "def vector_fit_get_or_make_label_encoder_fn(params: Params, problem, mode, label_list):\n",
    "    # don't need to make label encoder here\n",
    "    # set params num_classes for this problem\n",
    "    label_array = np.array(label_list)\n",
    "    params.num_classes[problem] = label_array.shape[-1]\n",
    "    return None\n",
    "\n",
    "params.register_problem_type(problem_type='vectorfit', top_layer=VectorFit, label_handling_fn=vector_fit_label_handling_fn, get_or_make_label_encoder_fn=vector_fit_get_or_make_label_encoder_fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"BaseParams.from_json\" class=\"doc_header\"><code>BaseParams.from_json</code><a href=\"https://github.com/JayYip/m3tl/tree/master/m3tl/base_params.py#L205\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>BaseParams.from_json</code>(**`json_path`**:`str`=*`None`*)\n",
       "\n",
       "Load json file as params.\n",
       "\n",
       "json_path could not be None if the problem is not assigned to params\n",
       "\n",
       "Args:\n",
       "    json_path (str, optional): Path to json file. Defaults to None.\n",
       "\n",
       "Raises:\n",
       "    AttributeError"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(BaseParams.from_json)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"BaseParams.to_json\" class=\"doc_header\"><code>BaseParams.to_json</code><a href=\"https://github.com/JayYip/m3tl/tree/master/m3tl/base_params.py#L191\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>BaseParams.to_json</code>()\n",
       "\n",
       "Save the params as json files. Please note that processing_fn is not saved.\n",
       "        "
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(BaseParams.to_json)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"BaseParams.parse_problem_string\" class=\"doc_header\"><code>BaseParams.parse_problem_string</code><a href=\"https://github.com/JayYip/m3tl/tree/master/m3tl/base_params.py#L239\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>BaseParams.parse_problem_string</code>(**`flag_string`**:`str`)\n",
       "\n"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(BaseParams.parse_problem_string)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Parse problem string\n",
    "\n",
    "Arguments: flag_string {str} -- problem string\n",
    "\n",
    "Returns: list -- problem list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "chained with |:  (['toy_cls', 'toy_seq_tag'], [['toy_seq_tag'], ['toy_cls']])\n",
      "chained with &:  (['toy_cls', 'toy_seq_tag'], [['toy_seq_tag', 'toy_cls']])\n"
     ]
    }
   ],
   "source": [
    "print('chained with |: ', params.parse_problem_string('toy_seq_tag|toy_cls'))\n",
    "print('chained with &: ', params.parse_problem_string('toy_seq_tag&toy_cls'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"BaseParams.get_problem_type\" class=\"doc_header\"><code>BaseParams.get_problem_type</code><a href=\"https://github.com/JayYip/m3tl/tree/master/m3tl/base_params.py#L411\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>BaseParams.get_problem_type</code>(**`problem`**:`str`)\n",
       "\n"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(BaseParams.get_problem_type)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'seq_tag'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "params.get_problem_type('toy_seq_tag')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"BaseParams.update_train_steps\" class=\"doc_header\"><code>BaseParams.update_train_steps</code><a href=\"https://github.com/JayYip/m3tl/tree/master/m3tl/base_params.py#L414\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>BaseParams.update_train_steps</code>(**`train_steps_per_epoch`**:`int`, **`epoch`**:`int`=*`None`*, **`warmup_ratio`**=*`0.1`*)\n",
       "\n",
       "If the batch_size is dynamic, we have to loop through the tf.data.Dataset\n",
       "to get the accurate number of training steps. In this case, we need a function to\n",
       "update the train_steps which will be used to calculate learning rate schedule.\n",
       "\n",
       "WARNING: updating should be called before the model is compiled!\n",
       "\n",
       "Args:\n",
       "    train_steps (int): new number of train_steps"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(BaseParams.update_train_steps)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If the batch_size is dynamic, we have to loop through the tf.data.Dataset\n",
    "to get the accurate number of training steps. In this case, we need a function to\n",
    "update the train_steps which will be used to calculate learning rate schedule.\n",
    "\n",
    "WARNING: updating should be called before the model is compiled! \n",
    "\n",
    "Args:\n",
    "- train_steps (int): new number of train_steps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1500 150\n"
     ]
    }
   ],
   "source": [
    "\n",
    "params.update_train_steps(train_steps_per_epoch=100)\n",
    "print(params.train_steps, params.num_warmup_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"BaseParams.assign_data_sampling_strategy\" class=\"doc_header\"><code>BaseParams.assign_data_sampling_strategy</code><a href=\"https://github.com/JayYip/m3tl/tree/master/m3tl/base_params.py#L568\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>BaseParams.assign_data_sampling_strategy</code>(**`sampling_strategy_name`**=*`'data_balanced'`*, **`sampling_strategy_fn`**:`Callable`=*`None`*)\n",
       "\n"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(BaseParams.assign_data_sampling_strategy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set data sampling strategy for multi-task learning.\n",
    "\n",
    "'data_balanced' and 'problem_balanced' is implemented by default.\n",
    "data_balanced: sampling weight equals to number of rows of that problem chunk.\n",
    "problem_balanced: sampling weight equals to 1 for every problem chunk.\n",
    "\n",
    "Args:\n",
    "- sampling_strategy (str, optional): sampling strategy. Defaults to 'data_balanced'.\n",
    "- sampling_strategy_fn (Callable, optional): function to create weight dict. Defaults to None.\n",
    "\n",
    "Raises:\n",
    "- NotImplementedError: sampling_strategy_fn is not implemented yet\n",
    "- ValueError: invalid sampling_strategy provided\n",
    "\n",
    "Returns:\n",
    "- Dict[str, float]: sampling weight for each problem_chunk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "params.assign_data_sampling_strategy(sampling_strategy_name='problem_balanced')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.3 64-bit ('base': conda)",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
